# core/formalization/rl/state.py
import numpy as np
from typing import Dict, List
import nltk
from nltk.sentiment import SentimentIntensityAnalyzer
import torch

from core.formalization.action_space import ActionType, get_action_index
from utils.logger import Logger
from llm.llm_wrapper import LLMWrapper
from llm.auxiliary import Auxiliary

# Length of the all-zero fallback vector returned by State.compute_state_vector
# when feature extraction raises. The vector built on the success path has
# 1 + len(actions) + 4 + 4 = 9 + len(actions) entries, so this constant only
# matches it when exactly 7 actions are configured.
STATE_DIM = 16

class State:
    """Turns a query pair and its interaction history into a float32 feature vector.

    Layout of the vector built by ``compute_state_vector``, in order:

    * 1 value              - similarity between the original and current
      query embeddings
    * len(actions) values  - 0/1 flags marking which action indices appear
      in the action history
    * 4 values             - VADER scores of the response: pos, neg, neu,
      compound
    * 4 values             - step_count, avg_response_length,
      input_length_ratio, n_sensitive_words
    """

    def __init__(self, logger: Logger, llm: LLMWrapper, auxiliary: Auxiliary, actions: List, config=None):
        """Store collaborators and make sure the VADER lexicon is available.

        Args:
            logger: Used for progress messages (``info``) and for reporting
                caught exceptions (``log_exception``).
            llm: Stored on the instance; not used by the methods of this class.
            auxiliary: Provides ``api_embedding`` and ``embedding_similarity``.
            actions: Its length fixes the size of the action-used segment of
                the state vector.
            config: Optional configuration mapping. When omitted (or ``None``)
                each instance gets its own empty dict.
        """
        self.logger = logger
        self.auxiliary = auxiliary
        self.llm = llm
        # Build a fresh dict per instance: a ``{}`` default in the signature
        # would be one object shared by every State created without a config.
        self.config = {} if config is None else config
        self.actions = actions

        # The lexicon is installed under nltk_data/sentiment/, so it has to be
        # looked up by that path. Looking it up by the bare package id never
        # succeeds and would trigger a download attempt on every construction.
        try:
            nltk.data.find('sentiment/vader_lexicon.zip')
        except LookupError:
            nltk.download('vader_lexicon')

        self.sentiment_analyzer = SentimentIntensityAnalyzer()

    def compute_state_vector(self,
                           original_query: str,
                           cur_query: str,
                           interaction_history: Dict) -> np.ndarray:
        """Assemble the state vector for the current step.

        Args:
            original_query: The query the interaction started from.
            cur_query: The query in its current form.
            interaction_history: Must contain ``'response'``. Optional keys,
                each with a default: ``'action_history'`` (``[]``),
                ``'step_count'`` (0), ``'avg_response_length'`` (0),
                ``'n_sensitive_words'`` (0).

        Returns:
            A 1-D float32 array laid out as described in the class docstring.
            If anything inside feature extraction raises, the exception is
            logged and an all-zero float32 array of length ``STATE_DIM`` is
            returned instead.

        Raises:
            ValueError: If ``interaction_history`` has no ``'response'`` key.
                This check runs before the guarded section, so it propagates.
        """
        if 'response' not in interaction_history:
            raise ValueError("No response in interaction history")

        try:
            self.logger.info("State start to compute state vec")

            # Embed both queries and reduce them to a single similarity score.
            original_semantic = self.auxiliary.api_embedding(original_query)
            current_semantic = self.auxiliary.api_embedding(cur_query)
            self.logger.info("Success to get original and current query embedding")
            original_tensor = torch.tensor(original_semantic)
            current_tensor = torch.tensor(current_semantic)
            semantic_sim = self.auxiliary.embedding_similarity(original_tensor, current_tensor)
            self.logger.info(f"Original and current query embedding similarity: [{semantic_sim}]")

            action_history = interaction_history.get('action_history', [])
            action_used = self._compute_action_used_vector(action_history)

            sentiment_features = self._compute_nltk_sentiment(interaction_history['response'])

            step_count = interaction_history.get('step_count', 0)
            avg_res_length = interaction_history.get('avg_response_length', 0)
            n_sensitive_words = interaction_history.get('n_sensitive_words', 0)
            # max(..., 1) keeps the ratio defined for an empty original query.
            input_length_ratio = len(cur_query) / max(len(original_query), 1)

            # NOTE(review): no perplexity is computed in this method; the
            # message looks like a leftover. Kept unchanged so log output
            # stays the same.
            self.logger.info(f"State start calculate the perplexity of [{cur_query}]")

            state_vector = np.concatenate([
                [semantic_sim],
                action_used,
                sentiment_features,
                [step_count],
                [avg_res_length],
                [input_length_ratio],
                [n_sensitive_words],
            ])

            return state_vector.astype(np.float32)

        except Exception as e:
            # Best-effort boundary: report the failure and hand back a neutral
            # vector rather than aborting the caller.
            # NOTE(review): the fallback length is STATE_DIM, while the vector
            # above has 9 + len(self.actions) entries - confirm they agree.
            self.logger.log_exception(e)
            return np.zeros(STATE_DIM, dtype=np.float32)

    def _compute_action_used_vector(self, action_history: List[Dict]) -> np.ndarray:
        """Return 0/1 flags marking which actions occur in ``action_history``.

        Args:
            action_history: Entries are dicts whose ``'action'`` value is
                passed to ``get_action_index`` to find the slot to set.

        Returns:
            A float32 array of length ``len(self.actions)`` with 1.0 at every
            index produced for an entry of the history and 0.0 elsewhere.
        """
        action_used = np.zeros(len(self.actions), dtype=np.float32)

        for entry in action_history:
            action_type: ActionType = entry['action']
            action_used[get_action_index(action_type)] = 1.0

        return action_used

    def _compute_nltk_sentiment(self, response: str) -> np.ndarray:
        """Score ``response`` with the VADER sentiment analyzer.

        Args:
            response: Text to score. A falsy value (``''`` or ``None``) is
                not scored.

        Returns:
            A float32 array ``[pos, neg, neu, compound]``. All zeros when
            ``response`` is falsy or when scoring raises (the exception is
            logged).
        """
        try:
            if not response:
                return np.zeros(4, dtype=np.float32)

            scores = self.sentiment_analyzer.polarity_scores(response)
            return np.array(
                [scores[key] for key in ('pos', 'neg', 'neu', 'compound')],
                dtype=np.float32,
            )
        except Exception as e:
            self.logger.log_exception(e)
            return np.zeros(4, dtype=np.float32)
